from Network.network import Network
from Network.network_utils import get_inplace_acti
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

def create_layers(inp_dim, out_dim, activation='none', norm=False, use_bias=True, res_layer=False):
    """Return the list of modules that make up a single network layer.

    The returned list is, in order: an optional ``nn.LayerNorm(inp_dim)``
    (when ``norm`` is truthy), a kernel-size-1 ``nn.Conv1d(inp_dim, out_dim)``,
    and an optional activation module from ``get_inplace_acti(activation)``
    (when ``activation != 'none'``).

    Args:
        inp_dim: Input channel count of the convolution; also the
            ``normalized_shape`` of the LayerNorm when ``norm`` is set.
        out_dim: Requested output channel count. It is halved
            (``int(out_dim / 2)``) when ``res_layer`` is truthy, and halved
            again when ``activation == 'crelu'`` (presumably because that
            activation doubles the channel count — confirm).
        activation: Activation name passed to ``get_inplace_acti``; ``'none'``
            means no activation module is appended.
        norm: Whether to prepend a LayerNorm.
        use_bias: Forwarded as ``bias`` to ``nn.Conv1d``.
        res_layer: Whether to halve ``out_dim``.

    Returns:
        A plain list of ``nn.Module`` instances (not a Sequential).
    """
    # Each enabled flag halves the requested width, applied in this order.
    for shrink in (res_layer, activation == 'crelu'):
        if shrink:
            out_dim = int(out_dim / 2)
    conv = nn.Conv1d(inp_dim, out_dim, 1, bias=use_bias)
    # NOTE(review): nn.LayerNorm(inp_dim) normalizes the *last* dimension,
    # whereas nn.Conv1d expects channels at dim 1 — confirm the intended
    # input layout when norm=True.
    head = [nn.LayerNorm(inp_dim)] if norm else []
    tail = [] if activation == 'none' else [get_inplace_acti(activation)]
    return head + [conv] + tail

class ResNetwork(Network):
    """Stack of 1x1 ``nn.Conv1d`` layers with optional concatenative skips.

    The module stack (``self.model``) is: an optional input ``nn.Dropout``,
    then, for each consecutive pair of sizes in
    ``[num_inputs] + hidden_sizes + [num_outputs]``, the modules returned by
    ``create_layers`` (optional LayerNorm, 1x1 Conv1d, and an activation on
    every layer except the last). ``forward`` runs the stack, applies
    ``self.activation_final`` and scales the result by ``self.scale_final``.
    """

    def __init__(self, args):
        """Build the layer stack from ``args``.

        Reads ``args.res.residual_layers``, ``args.scale_final``,
        ``args.activation``, ``args.activation_final``, ``args.use_bias`` and
        ``args.dropout``. ``self.num_inputs``, ``self.hidden_sizes``,
        ``self.num_outputs``, ``self.use_layer_norm`` and
        ``reset_network_parameters`` are presumably provided by the
        ``Network`` base class — TODO confirm.
        """
        super().__init__(args)
        residual_layers = args.res.residual_layers
        # NOTE(review): targets/sources are split by the *value's* parity (not
        # by position in the list) and then paired in order of appearance.
        # Confirm this is the intended encoding of residual_layers.
        targets = [i for i in residual_layers if i % 2 == 1]
        sources = [i for i in residual_layers if i % 2 == 0]
        # Maps a module index in self.model to the module indices whose inputs
        # get concatenated onto that module's input in forward().
        self.residual_dict = dict()
        for t, s in zip(targets, sources):
            self.residual_dict.setdefault(t, []).append(s)
        self.scale_final = args.scale_final
        self.is_crelu = args.activation == "crelu"
        if args.activation_final == "crelu":
            # A "crelu" final activation is replaced by leakyrelu. For any other
            # value, self.activation_final must already have been set elsewhere
            # (presumably by the Network base class — TODO confirm).
            self.activation_final = get_inplace_acti("leakyrelu")
        sizes = [self.num_inputs] + self.hidden_sizes + [self.num_outputs]
        # Every layer except the last gets args.activation; the last gets none.
        activations = [args.activation] * (len(sizes) - 2) + ['none']
        layers = list()
        for inp_dim, out_dim, acti in zip(sizes, sizes[1:], activations):
            layers += create_layers(inp_dim, out_dim, activation=acti,
                                    norm=self.use_layer_norm, use_bias=args.use_bias)
        if args.dropout > 0:  # only input dropout is supported for now
            layers = [nn.Dropout(args.dropout)] + layers
        self.model = nn.Sequential(*layers)
        self.train()
        self.reset_network_parameters()

    def forward(self, x):
        """Run ``x`` through the stack, the final activation, and the scale."""
        # layer_inputs[j] is the tensor that was fed into self.model[j].
        # NOTE(review): these are references, not copies; if an activation from
        # get_inplace_acti modifies its input in place, earlier entries will
        # observe the modified values — confirm this is intended.
        layer_inputs = list()
        for i, layer in enumerate(self.model):
            layer_inputs.append(x)
            # NOTE(review): create_layers builds nn.Conv1d modules, so this
            # branch never fires for models built by __init__. Also, dim=-1 is
            # the feature axis for nn.Linear but not the channel axis of Conv1d.
            if isinstance(layer, nn.Linear) and i in self.residual_dict:
                skips = [layer_inputs[j] for j in self.residual_dict[i]]
                x = torch.cat(skips + [x], dim=-1)
            x = layer(x)
        x = self.activation_final(x)
        x = x * self.scale_final

        return x